#ifndef SRC_TOPK_H_
#define SRC_TOPK_H_

#include <cstdlib>
#include <cstdint>
#include <cmath>
#include <algorithm>
#include <vector>
#include <tuple>
#include <random>
#include "countsketch.h"
#include "logistic.h"
#include "logistic_sketch.h"
#include "jl_recovery_sketch.h"
#include "black_box_reduction.h"
#include "heap.h"

namespace wmsketch {

/// Abstract base for classifiers that expose a set of top-k features held in
/// a TopKHeap.
///
/// Subclasses implement predict()/update(). The base class itself never
/// inserts into heap_; it only reads it in topk().
class TopKFeatures {
 protected:
  /// Number of features to retain; also the capacity passed to heap_.
  uint32_t k_;
  /// Candidate entries reported as (uint32_t, float) pairs by items().
  /// NOTE(review): presumably (feature index, weight); populated by subclasses.
  TopKHeap<uint32_t> heap_;

  /// @param k number of top features to track.
  explicit TopKFeatures(uint32_t k): k_{k}, heap_(k) { }

  /// @param k number of top features to track.
  /// The seed and pow arguments are accepted but not used by this base class.
  /// They are left unnamed to silence unused-parameter warnings.
  /// NOTE(review): presumably kept so subclass constructors can share one
  /// signature; confirm before removing.
  TopKFeatures(uint32_t k, int32_t /*seed*/, float /*pow*/ = 1.f): k_{k}, heap_(k) { }

 public:
  virtual ~TopKFeatures() = default;

  /// Copies the heap's entries into `out` and sorts them by descending
  /// absolute value of the float component.
  /// @param out destination vector. NOTE(review): whether it is cleared first
  ///            depends on TopKHeap::items; confirm.
  virtual void topk(std::vector<std::pair<uint32_t, float> >& out) {
    heap_.items(out);
    // Use std::fabs rather than unqualified fabs: <cmath> only guarantees the
    // std:: overloads, and the float overload avoids promotion to double.
    std::sort(out.begin(), out.end(),
        [](const auto& a, const auto& b) {
          return std::fabs(a.second) > std::fabs(b.second);
        });
  }

  /// Predicts a binary label for one example.
  /// NOTE(review): x is presumably sparse (feature index, value) pairs.
  virtual bool predict(const std::vector<std::pair<uint32_t, float> >& x) = 0;

  /// Trains on one labeled example.
  /// @return a flag whose meaning is defined by subclasses. NOTE(review):
  ///         possibly the pre-update prediction; confirm in implementations.
  virtual bool update(const std::vector<std::pair<uint32_t, float> >& x, bool label) = 0;

  /// Bias term of the model; the base implementation returns 0.
  virtual float bias() {
    return 0.f;
  }
};

/// Top-k feature tracker backed by a LogisticRegression model (lr_).
class LogisticTopK : public TopKFeatures {
 public:
  /// @param k        number of top features to track.
  /// @param dim      NOTE(review): presumably the feature-space dimension.
  /// @param lr_init  NOTE(review): presumably the initial learning rate.
  /// @param l2_reg   NOTE(review): presumably the L2 regularization strength.
  /// @param no_bias  NOTE(review): presumably disables the bias term when true.
  LogisticTopK(uint32_t k, uint32_t dim, float lr_init, float l2_reg, bool no_bias);
  ~LogisticTopK() override;

  bool predict(const std::vector<std::pair<uint32_t, float>>& x) override;
  bool update(const std::vector<std::pair<uint32_t, float>>& x, bool label) override;
  float bias() override;

 private:
  LogisticRegression lr_;           // underlying model
  std::vector<float> new_weights_;  // NOTE(review): purpose defined in the .cpp
};

class LogisticSketchTopK : public TopKFeatures {
 private:
  LogisticSketch sk_;
  std::vector<float> new_weights_;
  std::vector<uint32_t> idxs_;
  uint64_t t_;

 public:
  LogisticSketchTopK(
      uint32_t k,
      uint32_t log2_width,
      uint32_t depth,
      int32_t seed,
      float lr_init = 0.1,
      float l2_reg = 1e-3,
      bool median_update = false);
  ~LogisticSketchTopK();
  void topk(std::vector<std::pair<uint32_t, float> >& out);
  bool predict(const std::vector<std::pair<uint32_t, float> >& x);
  bool update(const std::vector<std::pair<uint32_t, float> >& x, bool label);
  float bias();

 private:
  void refresh_heap();
};


class JLRecoverySketchTopK : public TopKFeatures {

private:
	JLRecoverySketch sk_;
	std::vector<float> new_weights_;
	std::vector<uint32_t> idxs_;
	uint64_t t_;

public:
	JLRecoverySketchTopK(
		uint32_t k,
		uint32_t log2_width,
		uint32_t depth,
		int32_t seed,
		float lr_init = 0.1,
		float l2_reg = 1e-3);
	~JLRecoverySketchTopK();
	void topk(std::vector<std::pair<uint32_t, float>>& out);
	bool predict(const std::vector<std::pair<uint32_t, float>>& x);
	bool update(const std::vector<std::pair<uint32_t, float>>& x, bool label);
	float bias();

private:
	void refresh_heap();

};

class BlackBoxReductionTopK : public TopKFeatures {

private:
	BlackBoxReduction sk_;
	std::vector<float> new_weights_;
	std::vector<uint32_t> idxs_;
	uint64_t t_;

public:
	BlackBoxReductionTopK(
			uint32_t k,
			uint32_t log2_width,
			uint32_t depth,
			int32_t seed,
			float lr_init = 0.1,
			float l2_reg = 1e-3);
	~BlackBoxReductionTopK();
	void topk(std::vector<std::pair<uint32_t, float>>& out);
	bool predict(const std::vector<std::pair<uint32_t, float>>& x);
	bool update(const std::vector<std::pair<uint32_t, float>>& x, bool label);
	float bias();

private:
	void refresh_heap();
};

} // namespace wmsketch

#endif /* SRC_TOPK_H_ */
